package com.yifeng.repo.controller.traffic.in.processor;

import com.sun.tools.javac.api.JavacTrees;
import com.sun.tools.javac.processing.JavacProcessingEnvironment;
import com.sun.tools.javac.tree.JCTree;
import com.sun.tools.javac.tree.TreeMaker;
import com.sun.tools.javac.tree.TreeTranslator;
import com.sun.tools.javac.util.*;
import com.yifeng.repo.controller.traffic.in.processor.manager.TrafficInStatsManager;
import com.yifeng.repo.controller.traffic.apt.APT;
import com.yifeng.repo.controller.traffic.in.TrafficIn;

import javax.annotation.processing.*;
import javax.lang.model.SourceVersion;
import javax.lang.model.element.Element;
import javax.lang.model.element.ElementKind;
import javax.lang.model.element.TypeElement;
import javax.tools.Diagnostic;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Set;

/**
 * Annotation processor for {@link TrafficIn}.
 *
 * <p>For every method annotated with {@code @TrafficIn}, the method body is rewritten in place via
 * the javac tree API so that calls to {@code TrafficInStatsManager.checkPermit(...)} are inserted:
 * <ul>
 *   <li>as the first statements of the body, when the method-level annotation has
 *       {@code checkPermit() == true};</li>
 *   <li>immediately before each local-variable declaration in the body that is annotated with
 *       {@code @TrafficIn} and explicitly sets a non-blank {@code policyCode} and
 *       {@code checkPermit = true}.</li>
 * </ul>
 *
 * <p>This relies on javac internals ({@code com.sun.tools.javac.*}); {@link #init} casts the
 * processing environment to {@link JavacProcessingEnvironment}, so it only works under javac.
 *
 * Created by daibing on 2021/12/3.
 */
@SupportedAnnotationTypes("com.yifeng.repo.controller.traffic.in.TrafficIn")
public class TrafficInProcessor extends AbstractProcessor {
    private Messager messager;
    private JavacTrees javacTrees;
    private TreeMaker treeMaker;
    private Names names;
    private APT apt;

    @Override
    public SourceVersion getSupportedSourceVersion() {
        return SourceVersion.latest();
    }

    /**
     * Captures the javac-specific services (trees, tree factory, name table) used to rewrite ASTs.
     */
    @Override
    public synchronized void init(ProcessingEnvironment processingEnv) {
        super.init(processingEnv);
        messager = processingEnv.getMessager();
        javacTrees = JavacTrees.instance(processingEnv);
        Context context = ((JavacProcessingEnvironment) processingEnv).getContext();
        treeMaker = TreeMaker.instance(context);
        names = Names.instance(context);
        apt = new APT(treeMaker, names);
    }

    /**
     * Rewrites every method annotated with {@code @TrafficIn} in this round.
     * Non-method elements carrying the annotation are reported as compile errors.
     *
     * @return {@code true} (the annotation is claimed) for normal rounds, {@code false} for the
     *         final round
     */
    @Override
    public boolean process(Set<? extends TypeElement> annotations, RoundEnvironment roundEnv) {
        if (roundEnv.processingOver()) {
            return false;
        }
        for (TypeElement annotation : annotations) {
            for (Element element : roundEnv.getElementsAnnotatedWith(annotation)) {
                if (element.getKind() != ElementKind.METHOD) {
                    // Attach the element so the diagnostic points at the offending declaration.
                    messager.printMessage(Diagnostic.Kind.ERROR, "Only method can be annotated by traffic in!", element);
                    continue;
                }
                processMethod(element);
            }
        }
        return true;
    }

    /**
     * Rewrites the body of a single {@code @TrafficIn}-annotated method element in place.
     *
     * @param element a method element carrying {@code @TrafficIn}
     */
    private void processMethod(Element element) {
        TrafficIn traffic = element.getAnnotation(TrafficIn.class);
        String clazz = element.getEnclosingElement().getSimpleName().toString();

        JCTree methodTree = javacTrees.getTree(element);
        if (methodTree == null) {
            // No source tree is available for this element, so there is nothing to rewrite.
            return;
        }
        methodTree.accept(new TreeTranslator() {
            @Override
            public void visitMethodDef(JCTree.JCMethodDecl jcMethodDecl) {
                // Only methods that themselves carry @TrafficIn are rewritten; other visited methods
                // (e.g. in local/anonymous classes inside this body) are left untouched.
                // NOTE(review): an annotated method of a nested class is rewritten with the OUTER
                // method's annotation values and enclosing class name - confirm this is intended.
                if (apt.containAnnotation(jcMethodDecl.getModifiers().getAnnotations(), TrafficIn.class)) {
                    if (jcMethodDecl.body == null) {
                        // Abstract/interface/native methods have no body to instrument; previously
                        // this crashed the compiler with a NullPointerException.
                        messager.printMessage(Diagnostic.Kind.ERROR, "Method without body can not be annotated by traffic in!", element);
                    } else {
                        // Give generated nodes this method's source position rather than whatever
                        // position the shared TreeMaker was last left at.
                        treeMaker.at(jcMethodDecl.pos);
                        jcMethodDecl.body.stats = buildFullStatements(traffic, clazz, jcMethodDecl);
                        messager.printMessage(Diagnostic.Kind.NOTE, jcMethodDecl.getName() + " has been processed!");
                    }
                }
                super.visitMethodDef(jcMethodDecl);
            }
        });
    }

    /**
     * Builds the complete, instrumented statement list for a method body.
     * <ol>
     *   <li>The method itself must be annotated; only then do annotations inside the body take
     *       effect.</li>
     *   <li>The method-level annotation supports checking the permit by policy code or by method
     *       context; annotations inside the body only support checking by policy code.</li>
     *   <li>For the method-level annotation the permit check is inserted as the first statement(s)
     *       of the body; for an in-body annotation it is inserted just before the annotated
     *       statement.</li>
     * </ol>
     *
     * @param traffic      the method-level annotation
     * @param clazz        simple name of the enclosing class
     * @param jcMethodDecl method whose body (must be non-null) is being rewritten
     * @return the new statement list for the method body
     */
    private List<JCTree.JCStatement> buildFullStatements(TrafficIn traffic, String clazz, JCTree.JCMethodDecl jcMethodDecl) {
        // 1. source statements of method
        List<JCTree.JCStatement> sourceStatements = jcMethodDecl.getBody().getStatements();
        ListBuffer<JCTree.JCStatement> targetStatements = new ListBuffer<>();

        // 2. check trafficIn of declare of method
        if (traffic.checkPermit()) {
            targetStatements.addAll(buildTrafficStatements(traffic.policyCode(), clazz, jcMethodDecl));
        }

        // 3. check trafficIn of inside method body
        for (JCTree.JCStatement sourceStatement : sourceStatements) {
            // Not a local variable declaration: keep it unchanged.
            if (!(sourceStatement instanceof JCTree.JCVariableDecl)) {
                targetStatements.add(sourceStatement);
                continue;
            }
            // No @TrafficIn on the declaration: keep it unchanged.
            JCTree.JCAnnotation annotation = apt.getAnnotation(((JCTree.JCVariableDecl) sourceStatement).mods.getAnnotations(), TrafficIn.class);
            if (annotation == null) {
                targetStatements.append(sourceStatement);
                continue;
            }
            // Read the in-body annotation's arguments; a policy code must be given explicitly.
            // NOTE(review): values are taken from the source text of each argument expression, so a
            // constant reference (e.g. policyCode = Codes.X) yields "Codes.X", not its value.
            List<JCTree.JCExpression> arguments = annotation.getArguments();
            Map<String, String> key2Value = new HashMap<>(arguments.size());
            for (JCTree.JCExpression argument : arguments) {
                // Only `name = value` pairs are understood. A shorthand value such as
                // @TrafficIn("x") is not a JCAssign and used to crash with ClassCastException.
                if (!(argument instanceof JCTree.JCAssign)) {
                    continue;
                }
                JCTree.JCAssign assign = (JCTree.JCAssign) argument;
                key2Value.put(assign.getVariable().toString(), trimDoubleQuotationMark(assign.getExpression().toString()));
            }
            // NOTE(review): checkPermit must be written explicitly as true here; the annotation's
            // declared default (not visible from this file) is not consulted.
            if (isBlank(key2Value.get("policyCode")) || !Boolean.parseBoolean(key2Value.get("checkPermit"))) {
                targetStatements.append(sourceStatement);
                continue;
            }
            targetStatements.addAll(buildTrafficStatements(key2Value.get("policyCode"), clazz, jcMethodDecl));
            targetStatements.append(sourceStatement);
        }
        return targetStatements.toList();
    }

    /**
     * Builds the permit-check statements. With a non-blank policy code this is a single call:
     * <pre>TrafficInStatsManager.checkPermit("9999333ddd");</pre>
     * Otherwise an ordered argument map is built from the method's parameters and the context
     * overload is called:
     * <pre>TrafficInStatsManager.checkPermit("EndpointServiceImpl", "pushStock", 6, argMap);</pre>
     *
     * @param policyCode   policy code, or blank to check by method context
     * @param clazz        simple name of the enclosing class
     * @param jcMethodDecl method being instrumented (supplies name and parameters)
     * @return statements to insert into the method body
     */
    private List<JCTree.JCStatement> buildTrafficStatements(String policyCode, String clazz, JCTree.JCMethodDecl jcMethodDecl) {
        // check permit by policy code
        if (!isBlank(policyCode)) {
            return List.of(treeMaker.Exec(
                    treeMaker.Apply(
                            List.nil(),
                            apt.memberAccess(TrafficInStatsManager.class.getName() + "." + "checkPermit"),
                            List.of(treeMaker.Literal(policyCode))
                    )
            ));
        }

        // check permit by method context
        // NOTE(review): the generated local is named "argMap"; a method parameter or local with the
        // same name would make the generated code fail to compile.
        List<String> argNames = apt.getArgNames(jcMethodDecl.getParameters());
        Name argMapInstance = names.fromString("argMap");
        ListBuffer<JCTree.JCExpression> exprList = new ListBuffer<>();
        exprList.add(treeMaker.Literal(clazz));
        exprList.add(treeMaker.Literal(jcMethodDecl.getName().toString()));
        exprList.add(treeMaker.Literal(argNames.size()));
        exprList.add(treeMaker.Ident(argMapInstance));
        ListBuffer<JCTree.JCStatement> statements = new ListBuffer<>();
        statements.addAll(this.buildArgMap(argMapInstance, argNames));
        statements.add(treeMaker.Exec(
                treeMaker.Apply(
                        List.nil(),
                        apt.memberAccess(TrafficInStatsManager.class.getName() + "." + "checkPermit"),
                        exprList.toList()
                )
        ));
        return statements.toList();
    }

    /**
     * Builds the declaration and population of the ordered argument map:
     * <pre>
     * Map&lt;String, Object&gt; argMap = new LinkedHashMap&lt;&gt;();
     * argMap.put("batchNo", batchNo);
     * argMap.put("xxx", xxx);
     * </pre>
     * Note: the entry order of argMap must not be changed - callees rebuild the call arguments in
     * this order, which is why a {@link LinkedHashMap} is used.
     *
     * @param argMapInstance name of the map variable to declare
     * @param argNames       parameter names, in declaration order
     * @return the declaration statement followed by one {@code put} per parameter
     */
    private List<JCTree.JCStatement> buildArgMap(Name argMapInstance, List<String> argNames) {
        ListBuffer<JCTree.JCStatement> statements = new ListBuffer<>();
        statements.add(treeMaker.VarDef(
                treeMaker.Modifiers(0),
                argMapInstance,
                apt.memberAccess(Map.class.getName(), String.class.getName(), Object.class.getName()),
                treeMaker.NewClass(null, List.nil(), apt.memberAccess(LinkedHashMap.class.getName()), List.nil(), null)
        ));
        for (String argName : argNames) {
            statements.add(apt.methodByMap(argMapInstance.toString(), "put", argName, argName));
        }
        return statements.toList();
    }

    /**
     * Strips one leading and/or one trailing double quote, independently of each other.
     * A lone {@code "} yields the empty string (previously this threw
     * StringIndexOutOfBoundsException).
     */
    private static String trimDoubleQuotationMark(String s) {
        int start = s.startsWith("\"") ? 1 : 0;
        int end = s.endsWith("\"") ? s.length() - 1 : s.length();
        return start >= end ? "" : s.substring(start, end);
    }

    /** Returns true when {@code s} is null, empty, or only whitespace (per {@link String#trim()}). */
    private static boolean isBlank(String s) {
        return s == null || s.trim().length() == 0;
    }

}
